# Authors: The MNE-Python contributors.
# License: BSD-3-Clause
# Copyright the MNE-Python contributors.

from pathlib import Path

import numpy as np
import pytest
from numpy.testing import assert_allclose, assert_array_equal, assert_array_less
from scipy.interpolate import interp1d
from scipy.spatial.distance import cdist

from mne import pick_info, pick_types
from mne._fiff.constants import FIFF
from mne.chpi import (
    _chpi_locs_to_times_dig,
    _compute_good_distances,
    _get_hpi_initial_fit,
    _setup_ext_proj,
    compute_chpi_amplitudes,
    compute_chpi_locs,
    compute_chpi_snr,
    compute_head_pos,
    extract_chpi_locs_ctf,
    extract_chpi_locs_kit,
    filter_chpi,
    get_active_chpi,
    get_chpi_info,
    head_pos_to_trans_rot_t,
    read_head_pos,
    refit_hpi,
    write_head_pos,
)
from mne.datasets import testing
from mne.forward._compute_forward import _MAG_FACTOR
from mne.io import (
    RawArray,
    read_info,
    read_raw_artemis123,
    read_raw_ctf,
    read_raw_fif,
    read_raw_kit,
)
from mne.simulation import add_chpi
from mne.transforms import (
    _angle_between_quats,
    angle_distance_between_rigid,
    rot_to_quat,
)
from mne.utils import (
    _record_warnings,
    assert_meg_snr,
    catch_logging,
    object_diff,
    verbose,
)
from mne.utils._testing import assert_trans_allclose
from mne.viz import plot_head_positions

base_dir = Path(__file__).parents[1] / "io" / "tests" / "data"
ctf_fname = base_dir / "test_ctf_raw.fif"
hp_fif_fname = base_dir / "test_chpi_raw_sss.fif"
raw_fname = base_dir / "test_raw.fif"

data_path = testing.data_path(download=False)
sample_fname = data_path / "MEG" / "sample" / "sample_audvis_trunc_raw.fif"
chpi_fif_fname = data_path / "SSS" / "test_move_anon_raw.fif"
pos_fname = data_path / "SSS" / "test_move_anon_raw.pos"
sss_fif_fname = data_path / "SSS" / "test_move_anon_raw_sss.fif"
sss_hpisubt_fname = data_path / "SSS" / "test_move_anon_hpisubt_raw.fif"
chpi5_fif_fname = data_path / "SSS" / "chpi5_raw.fif"
chpi5_pos_fname = data_path / "SSS" / "chpi5_raw_mc.pos"
ctf_chpi_fname = data_path / "CTF" / "testdata_ctf_mc.ds"
ctf_chpi_pos_fname = data_path / "CTF" / "testdata_ctf_mc.pos"
chpi_problem_fname = data_path / "SSS" / "chpi_problematic-info.fif"

art_fname = (
    data_path
    / "ARTEMIS123"
    / "Artemis_Data_2017-04-04-15h-44m-22s_Motion_Translation-z.bin"
)
art_mc_fname = (
    data_path
    / "ARTEMIS123"
    / "Artemis_Data_2017-04-04-15h-44m-22s_Motion_Translation-z_mc.pos"
)

con_fname = data_path / "KIT" / "MQKIT_125_2sec.con"
mrk_fname = data_path / "KIT" / "MQKIT_125.mrk"
elp_fname = data_path / "KIT" / "MQKIT_125.elp"
hsp_fname = data_path / "KIT" / "MQKIT_125.hsp"
berlin_fname = data_path / "KIT" / "data_berlin.con"


@testing.requires_testing_data
def test_chpi_adjust():
    """Test cHPI logging and adjustment."""
    raw = read_raw_fif(chpi_fif_fname, allow_maxshield="yes")
    with catch_logging() as log:
        _get_hpi_initial_fit(raw.info, adjust=True, verbose="debug")
        get_chpi_info(raw.info, on_missing="raise", verbose="debug")
    # Ran MaxFilter (with -list, -v, -movecomp, etc.), and got:
    msg = [
        "HPIFIT: 5 coils digitized in order 5 1 4 3 2",
        "HPIFIT: 3 coils accepted: 1 2 4",
        "HPI coil moments (3, 5):",
        "2.08542e-15 -1.52486e-15 -1.53484e-15",
        "2.14516e-15 2.09608e-15 7.30303e-16",
        "-3.2318e-16 -4.25666e-16 2.69997e-15",
        "5.21717e-16 1.28406e-15 1.95335e-15",
        "1.21199e-15 -1.25801e-19 1.18321e-15",
        "HPIFIT errors:  0.3, 0.3, 5.3, 0.4, 3.2 mm.",
        "HPI consistency of isotrak and hpifit is OK.",
        "HP fitting limits: err = 5.0 mm, gval = 0.980.",
        "Using 5 HPI coils: 83 143 203 263 323 Hz",  # actually came earlier
    ]

    log = log.getvalue().splitlines()
    assert set(log) == set(msg), "\n" + "\n".join(set(msg) - set(log))

    # Then took the raw file, did this:
    raw.info["dig"][5]["r"][2] += 1.0
    # And checked the result in MaxFilter, which changed the logging as:
    msg = (
        msg[:8]
        + [
            "HPIFIT errors:  0.3, 0.3, 5.3, 999.7, 3.2 mm.",
            "Note: HPI coil 3 isotrak is adjusted by 5.3 mm!",
            "Note: HPI coil 5 isotrak is adjusted by 3.2 mm!",
        ]
        + msg[-2:]
    )
    with catch_logging() as log:
        _get_hpi_initial_fit(raw.info, adjust=True, verbose="debug")
        get_chpi_info(raw.info, on_missing="raise", verbose="debug")
    log = log.getvalue().splitlines()
    assert set(log) == set(msg), "\n" + "\n".join(set(msg) - set(log))


@testing.requires_testing_data
def test_read_write_head_pos(tmp_path):
    """Test reading and writing head position quaternion parameters."""
    temp_name = tmp_path / "temp.pos"
    # This isn't a 100% valid quat matrix but it should be okay for tests
    head_pos_rand = np.random.RandomState(0).randn(20, 10)
    # This one is valid
    head_pos_read = read_head_pos(pos_fname)
    for head_pos_orig in (head_pos_rand, head_pos_read):
        write_head_pos(temp_name, head_pos_orig)
        head_pos = read_head_pos(temp_name)
        assert_allclose(head_pos_orig, head_pos, atol=1e-3)
    # Degenerate cases
    pytest.raises(TypeError, write_head_pos, 0, head_pos_read)  # not filename
    pytest.raises(ValueError, write_head_pos, temp_name, "foo")  # not array
    pytest.raises(ValueError, write_head_pos, temp_name, head_pos_read[:, :9])
    pytest.raises(TypeError, read_head_pos, 0)
    pytest.raises(OSError, read_head_pos, "101")


@testing.requires_testing_data
def test_hpi_info(tmp_path):
    """Test getting HPI info."""
    temp_name = tmp_path / "temp_raw.fif"
    for fname in (chpi_fif_fname, sss_fif_fname):
        raw = read_raw_fif(fname, allow_maxshield="yes").crop(0, 0.1)
        assert len(raw.info["hpi_subsystem"]) > 0
        raw.save(temp_name, overwrite=True)
        info = read_info(temp_name)
        assert len(info["hpi_subsystem"]) == len(raw.info["hpi_subsystem"])

    # test get_chpi_info()
    info = read_info(chpi_fif_fname)
    hpi_freqs, stim_ch_idx, hpi_on_codes = get_chpi_info(info)

    assert_allclose(hpi_freqs, np.array([83.0, 143.0, 203.0, 263.0, 323.0]))
    assert stim_ch_idx == 378
    assert_allclose(hpi_on_codes, np.array([256, 512, 1024, 2048, 4096]))

    # test get_chpi_info() if no proper cHPI info is available
    with info._unlock():
        info["hpi_subsystem"] = None
        info["hpi_meas"] = []
        info["hpi_results"] = []

    with pytest.raises(ValueError, match="No appropriate cHPI information"):
        get_chpi_info(info)

    with pytest.warns(RuntimeWarning, match="No appropriate cHPI information"):
        get_chpi_info(info, on_missing="warn")

    hpi_freqs, stim_ch_idx, hpi_on_codes = get_chpi_info(info, on_missing="ignore")
    assert_array_equal([], hpi_freqs)
    assert stim_ch_idx is None
    assert_array_equal([], hpi_on_codes)


def _assert_quats(
    actual,
    desired,
    dist_tol=0.003,
    angle_tol=5.0,
    err_rtol=0.5,
    gof_rtol=0.001,
    vel_atol=2e-3,
):  # 2 mm/s
    """Compare estimated cHPI positions."""
    __tracebackhide__ = True
    trans_est, rot_est, t_est = head_pos_to_trans_rot_t(actual)
    trans, rot, t = head_pos_to_trans_rot_t(desired)
    quats_est = rot_to_quat(rot_est)
    gofs, errs, vels = desired[:, 7:].T
    gofs_est, errs_est, vels_est = actual[:, 7:].T
    del actual, desired

    # maxfilter produces some times that are implausibly large (weird)
    if not np.isclose(t[0], t_est[0], atol=1e-1):  # within 100 ms
        raise AssertionError(
            f"Start times not within 100 ms: {t[0]:0.3f} != {t_est[0]:0.3f}"
        )
    use_mask = (t >= t_est[0]) & (t <= t_est[-1])
    t = t[use_mask]
    trans = trans[use_mask]
    quats = rot_to_quat(rot)
    quats = quats[use_mask]
    gofs, errs, vels = gofs[use_mask], errs[use_mask], vels[use_mask]

    # double-check our angle function
    for q in (quats, quats_est):
        angles = _angle_between_quats(q, q)
        assert_allclose(angles, 0.0, atol=1e-5)

    # limit translation difference between MF and our estimation
    trans_est_interp = interp1d(t_est, trans_est, axis=0)(t)
    distances = np.sqrt(np.sum((trans - trans_est_interp) ** 2, axis=1))
    assert np.isfinite(distances).all()
    arg_worst = np.argmax(distances)
    assert distances[arg_worst] <= dist_tol, (
        f"@ {t[arg_worst]:0.3f} seconds: "
        f"{1000 * distances[arg_worst]:0.3f} > {1000 * dist_tol:0.3f} mm"
    )

    # limit rotation difference between MF and our estimation
    # (note that the interpolation will make this slightly worse)
    quats_est_interp = interp1d(t_est, quats_est, axis=0)(t)
    angles = 180 * _angle_between_quats(quats_est_interp, quats) / np.pi
    arg_worst = np.argmax(angles)
    assert angles[arg_worst] <= angle_tol, (
        f"@ {t[arg_worst]:0.3f} seconds: "
        f"{angles[arg_worst]:0.3f} > {angle_tol:0.3f} deg"
    )

    # error calculation difference
    errs_est_interp = interp1d(t_est, errs_est)(t)
    assert_allclose(
        errs_est_interp, errs, rtol=err_rtol, atol=1e-3, err_msg="err"
    )  # 1 mm

    # gof calculation difference
    gof_est_interp = interp1d(t_est, gofs_est)(t)
    assert_allclose(gof_est_interp, gofs, rtol=gof_rtol, atol=1e-7, err_msg="gof")

    # velocity calculation difference
    vel_est_interp = interp1d(t_est, vels_est)(t)
    assert_allclose(vel_est_interp, vels, atol=vel_atol, err_msg="velocity")


def _decimate_chpi(raw, decim=4):
    """Decimate raw data (with aliasing) in cHPI-fitting compatible way."""
    raw_dec = RawArray(
        raw._data[:, ::decim], raw.info, first_samp=raw.first_samp // decim
    )
    with raw_dec.info._unlock():
        raw_dec.info["sfreq"] /= decim
    for coil in raw_dec.info["hpi_meas"][0]["hpi_coils"]:
        if coil["coil_freq"] > raw_dec.info["sfreq"]:
            coil["coil_freq"] = np.mod(coil["coil_freq"], raw_dec.info["sfreq"])
            if coil["coil_freq"] > raw_dec.info["sfreq"] / 2.0:
                coil["coil_freq"] = raw_dec.info["sfreq"] - coil["coil_freq"]
    return raw_dec


# A shortcut method for testing that does both steps
@verbose
def _calculate_chpi_positions(
    raw,
    t_step_min=0.01,
    t_step_max=1.0,
    t_window="auto",
    too_close="raise",
    dist_limit=0.005,
    gof_limit=0.98,
    ext_order=1,
    verbose=None,
):
    chpi_amplitudes = compute_chpi_amplitudes(
        raw,
        t_step_min=t_step_min,
        t_window=t_window,
        ext_order=ext_order,
        verbose=verbose,
    )
    chpi_locs = compute_chpi_locs(
        raw.info,
        chpi_amplitudes,
        t_step_max=t_step_max,
        too_close=too_close,
        verbose=verbose,
    )
    head_pos = compute_head_pos(
        raw.info, chpi_locs, dist_limit=dist_limit, gof_limit=gof_limit, verbose=verbose
    )
    return head_pos


@testing.requires_testing_data
def test_calculate_chpi_positions_preload():
    """Test calculation of cHPI positions with and without data loaded."""
    raw = read_raw_fif(chpi_fif_fname, allow_maxshield="yes").crop(0, 2)
    kwargs = dict(t_step_min=0.1, t_window="auto", verbose=True)
    pos = compute_chpi_amplitudes(raw, **kwargs)
    raw.load_data()
    pos_preload = compute_chpi_amplitudes(raw, **kwargs)
    assert object_diff(pos, pos_preload) == ""


@pytest.mark.slowtest
@testing.requires_testing_data
def test_calculate_chpi_positions_vv():
    """Test calculation of cHPI positions."""
    # Check to make sure our fits match MF decently
    mf_quats = read_head_pos(pos_fname)
    raw = read_raw_fif(chpi_fif_fname, allow_maxshield="yes")
    raw.crop(0, 5).load_data()
    # check "auto" t_window estimation at full sampling rate
    with catch_logging() as log:
        compute_chpi_amplitudes(
            raw, t_step_min=0.1, t_window="auto", tmin=0, tmax=2, verbose=True
        )
    assert "83.3 ms" in log.getvalue()
    # This is a little hack (aliasing while decimating) to make it much faster
    # for testing purposes only. We can relax this later if we find it breaks
    # something.
    raw_dec = _decimate_chpi(raw, 15)
    with catch_logging() as log:
        with pytest.warns(RuntimeWarning, match="cannot determine"):
            py_quats = _calculate_chpi_positions(raw_dec, t_window=0.2, verbose="debug")
    log = log.getvalue()
    assert "\nHPIFIT" in log
    assert "Computing 4385 HPI location guesses" in log
    _assert_quats(py_quats, mf_quats, dist_tol=0.001, angle_tol=0.7)
    # degenerate conditions
    raw_no_chpi = read_raw_fif(sample_fname)
    with pytest.raises(ValueError, match="No appropriate cHPI information"):
        _calculate_chpi_positions(raw_no_chpi)
    raw_bad = raw.copy()
    del raw_bad.info["hpi_meas"][0]["hpi_coils"][0]["coil_freq"]
    with pytest.raises(ValueError, match="No appropriate cHPI information"):
        _calculate_chpi_positions(raw_bad)
    raw_bad = raw.copy()
    for d in raw_bad.info["dig"]:
        if d["kind"] == FIFF.FIFFV_POINT_HPI:
            d["coord_frame"] = FIFF.FIFFV_COORD_UNKNOWN
            break
    with pytest.raises(RuntimeError, match="coordinate frame incorrect"):
        _calculate_chpi_positions(raw_bad)
    for d in raw_bad.info["dig"]:
        if d["kind"] == FIFF.FIFFV_POINT_HPI:
            d["coord_frame"] = FIFF.FIFFV_COORD_HEAD
            d["r"] = np.ones(3)
    raw_bad.crop(0, 1.0)
    picks = np.concatenate(
        [
            np.arange(306, len(raw_bad.ch_names)),
            pick_types(raw_bad.info, meg=True)[::16],
        ]
    )
    raw_bad.pick([raw_bad.ch_names[pick] for pick in picks])
    with _record_warnings(), pytest.warns(RuntimeWarning, match="Discrepancy"):
        with catch_logging() as log_file:
            _calculate_chpi_positions(raw_bad, t_step_min=1.0, verbose=True)
    # ignore HPI info header and [done] footer
    assert "0/5 good HPI fits" in log_file.getvalue()

    # half the rate cuts off cHPI coils
    with raw.info._unlock():
        raw.info["lowpass"] /= 2.0
    with pytest.raises(RuntimeError, match="above the"):
        _calculate_chpi_positions(raw)


@testing.requires_testing_data
@pytest.mark.slowtest
def test_calculate_chpi_snr():
    """Test cHPI SNR calculation."""
    raw = read_raw_fif(chpi_fif_fname, allow_maxshield="yes")
    # include handling of NaN (when cHPI was off at the beginning)
    raw.load_data()
    raw.info["bads"] = ["MEG0342", "MEG1443"]
    stop = int(round(raw.info["sfreq"])) * 2
    raw._data[raw.ch_names.index("STI201"), :stop] = 0

    result = compute_chpi_snr(raw)
    # make sure all the entries are there
    keys = {
        f"{ch_type}_{key}"
        for ch_type in ("mag", "grad")
        for key in ("snr", "power", "resid")
    }
    assert set(result) == keys.union({"times", "freqs"})
    # make sure the values are plausible, given the sample data file
    n_pts = len(raw.times) // int(round(raw.info["sfreq"] * 0.01))
    # our logic in this test for this length is not perfect
    assert_allclose(result["mag_snr"].shape[0], n_pts, atol=5)
    n_nan = np.where(result["times"] <= raw.first_time + 2)[0][-1]
    assert_allclose(result["mag_snr"][:n_nan], np.nan)
    assert result["mag_snr"][n_nan:].min() > 1
    assert result["mag_snr"][n_nan:].max() < 40
    assert result["grad_snr"][n_nan:].min() > 1
    assert result["grad_snr"][n_nan:].max() < 40


@testing.requires_testing_data
@pytest.mark.slowtest
def test_calculate_chpi_positions_artemis():
    """Test on 5k artemis data."""
    raw = read_raw_artemis123(art_fname, preload=True)
    mf_quats = read_head_pos(art_mc_fname)
    mf_quats[:, 8:] /= 100  # old code errantly had this factor
    py_quats = _calculate_chpi_positions(raw, t_step_min=2.0, verbose="debug")
    _assert_quats(
        py_quats, mf_quats, dist_tol=0.001, angle_tol=1.0, err_rtol=0.7, vel_atol=1e-2
    )


@testing.requires_testing_data
def test_warn_maxwell_filtered():
    """Test that trying to compute locations on Maxwell filtered data warns."""
    raw = read_raw_fif(sss_fif_fname).crop(0, 1)
    with pytest.warns(RuntimeWarning, match="Maxwell filter"):
        amps = compute_chpi_amplitudes(raw)
    assert len(amps["times"]) > 0  # but for this file, it does work!


@testing.requires_testing_data
def test_initial_fit_redo():
    """Test that initial fits can be redone based on moments."""
    raw = read_raw_fif(chpi_fif_fname, allow_maxshield="yes")
    slopes = np.array([[c["slopes"] for c in raw.info["hpi_meas"][0]["hpi_coils"]]])
    amps = np.linalg.norm(slopes, axis=-1)
    amps /= slopes.shape[-1]
    assert_array_less(amps, 5e-11)
    assert_array_less(1e-12, amps)
    proj, _, _ = _setup_ext_proj(raw.info, ext_order=1)
    chpi_amplitudes = dict(times=np.zeros(1), slopes=slopes, proj=proj)
    chpi_locs = compute_chpi_locs(raw.info, chpi_amplitudes)

    # check GOF
    coil_gof = raw.info["hpi_results"][0]["goodness"]
    assert_allclose(chpi_locs["gofs"][0], coil_gof, atol=0.3)  # XXX not good

    # check moment
    # XXX our forward and theirs differ by an extra mult by _MAG_FACTOR
    coil_moment = raw.info["hpi_results"][0]["moments"] / _MAG_FACTOR
    py_moment = chpi_locs["moments"][0]
    coil_amp = np.linalg.norm(coil_moment, axis=-1, keepdims=True)
    py_amp = np.linalg.norm(py_moment, axis=-1, keepdims=True)
    assert_allclose(coil_amp, py_amp, rtol=0.2)
    coil_ori = coil_moment / coil_amp
    py_ori = py_moment / py_amp
    angles = np.rad2deg(np.arccos(np.abs(np.sum(coil_ori * py_ori, axis=1))))
    assert_array_less(angles, 20)

    # check resulting dev_head_t
    head_pos = compute_head_pos(raw.info, chpi_locs)
    assert head_pos.shape == (1, 10)
    nm_pos = raw.info["dev_head_t"]["trans"]
    dist = 1000 * np.linalg.norm(nm_pos[:3, 3] - head_pos[0, 4:7])
    assert 0.1 < dist < 2
    angle = np.rad2deg(
        _angle_between_quats(rot_to_quat(nm_pos[:3, :3]), head_pos[0, 1:4])
    )
    assert 0.1 < angle < 2
    gof = head_pos[0, 7]
    assert_allclose(gof, 0.9999, atol=1e-4)


@testing.requires_testing_data
def test_calculate_head_pos_chpi_on_chpi5_in_one_second_steps():
    """Comparing estimated cHPI positions with MF results (one second)."""
    # Check to make sure our fits match MF decently
    mf_quats = read_head_pos(chpi5_pos_fname)
    raw = read_raw_fif(chpi5_fif_fname, allow_maxshield="yes")
    # the last two seconds contain a maxfilter problem!
    # fiff file timing: 26. to 43. s
    # maxfilter estimates a wrong head position for interval 16: 41.-42. s
    raw = _decimate_chpi(raw.crop(0.0, 10.0).load_data(), decim=8)
    # needs no interpolation, because maxfilter pos files comes with 1 s steps
    py_quats = _calculate_chpi_positions(
        raw, t_step_min=1.0, t_step_max=1.0, t_window=1.0, verbose="debug"
    )
    _assert_quats(
        py_quats, mf_quats, dist_tol=0.002, angle_tol=1.2, vel_atol=3e-3
    )  # 3 mm/s


@pytest.mark.slowtest
@testing.requires_testing_data
def test_calculate_head_pos_chpi_on_chpi5_in_shorter_steps():
    """Comparing estimated cHPI positions with MF results (smaller steps)."""
    # Check to make sure our fits match MF decently
    mf_quats = read_head_pos(chpi5_pos_fname)
    raw = read_raw_fif(chpi5_fif_fname, allow_maxshield="yes")
    raw = _decimate_chpi(raw.crop(0.0, 5.0).load_data(), decim=8)
    with pytest.warns(RuntimeWarning, match="cannot determine"):
        py_quats = _calculate_chpi_positions(
            raw, t_step_min=0.1, t_step_max=0.1, t_window=0.1, verbose="debug"
        )
    # needs interpolation, tolerance must be increased
    _assert_quats(
        py_quats, mf_quats, dist_tol=0.002, angle_tol=1.2, vel_atol=0.02
    )  # 2 cm/s is not great but probably fine


def test_simulate_calculate_head_pos_chpi():
    """Test calculation of cHPI positions with simulated data."""
    # Read info dict from raw FIF file
    info = read_info(raw_fname)
    # Tune the info structure
    chpi_channel = "STI201"
    ncoil = len(info["hpi_results"][0]["order"])
    coil_freq = 10 + np.arange(ncoil) * 5
    hpi_subsystem = {
        "event_channel": chpi_channel,
        "hpi_coils": [
            {"event_bits": np.array([256, 0, 256, 256], dtype=np.int32)},
            {"event_bits": np.array([512, 0, 512, 512], dtype=np.int32)},
            {"event_bits": np.array([1024, 0, 1024, 1024], dtype=np.int32)},
            {"event_bits": np.array([2048, 0, 2048, 2048], dtype=np.int32)},
        ],
        "ncoil": ncoil,
    }

    with info._unlock():
        info["hpi_subsystem"] = hpi_subsystem
        for fi, freq in enumerate(coil_freq):
            info["hpi_meas"][0]["hpi_coils"][fi]["coil_freq"] = freq
        info["sfreq"] = 100.0  # this will speed it up a lot
    picks = pick_types(info, meg=True, stim=True, eeg=False, exclude=[])
    info = pick_info(info, picks)
    info["chs"][info["ch_names"].index("STI 001")]["ch_name"] = "STI201"
    info._update_redundant()
    with info._unlock():
        info["projs"] = []

    info_trans = info["dev_head_t"]["trans"].copy()

    dev_head_pos_ini = np.concatenate(
        [rot_to_quat(info_trans[:3, :3]), info_trans[:3, 3]]
    )
    ez = np.array([0, 0, 1])  # Unit vector in z-direction of head coordinates

    # Define some constants
    duration = 10  # Time / s

    # Quotient of head position sampling frequency
    # and raw sampling frequency
    head_pos_sfreq_quotient = 0.01

    # Round number of head positions to the next integer
    S = int(duration * info["sfreq"] * head_pos_sfreq_quotient)
    assert S == 10
    dz = 0.001  # Shift in z-direction is 0.1mm for each step

    dev_head_pos = np.zeros((S, 10))
    dev_head_pos[:, 0] = np.arange(S) * info["sfreq"] * head_pos_sfreq_quotient
    dev_head_pos[:, 1:4] = dev_head_pos_ini[:3]
    dev_head_pos[:, 4:7] = dev_head_pos_ini[3:] + np.outer(np.arange(S) * dz, ez)
    dev_head_pos[:, 7] = 1.0

    # m/s
    dev_head_pos[:, 9] = dz / (info["sfreq"] * head_pos_sfreq_quotient)

    # Round number of samples to the next integer
    raw_data = np.zeros((len(picks), int(duration * info["sfreq"] + 0.5)))
    raw = RawArray(raw_data, info)
    add_chpi(raw, dev_head_pos)
    quats = _calculate_chpi_positions(
        raw,
        t_step_min=raw.info["sfreq"] * head_pos_sfreq_quotient,
        t_step_max=raw.info["sfreq"] * head_pos_sfreq_quotient,
        t_window=1.0,
    )
    _assert_quats(
        quats, dev_head_pos, dist_tol=0.001, angle_tol=1.0, vel_atol=4e-3
    )  # 4 mm/s


def _calculate_chpi_coil_locs(raw, verbose):
    """Wrap to facilitate change diff."""
    chpi_amplitudes = compute_chpi_amplitudes(raw, verbose=verbose)
    chpi_locs = compute_chpi_locs(raw.info, chpi_amplitudes, verbose=verbose)
    return _chpi_locs_to_times_dig(chpi_locs)


def _check_dists(info, cHPI_digs, n_bad=0, bad_low=0.02, bad_high=0.04):
    __tracebackhide__ = True
    orig = _get_hpi_initial_fit(info)
    hpi_coil_distances = cdist(orig, orig)
    new_pos = np.array([d["r"] for d in cHPI_digs])
    mask, distances = _compute_good_distances(hpi_coil_distances, new_pos)
    good_idx = np.where(mask)[0]
    assert len(good_idx) >= 3
    meds = np.empty(len(orig))
    for ii in range(len(orig)):
        idx = np.setdiff1d(good_idx, ii)
        meds[ii] = np.median(distances[ii][idx])
    meds = np.array(meds)
    assert_array_less(meds[good_idx], 0.003)
    bad_idx = np.where(~mask)[0]
    if len(bad_idx):
        bads = meds[bad_idx]
        assert_array_less(bad_low, bads)
        assert_array_less(bads, bad_high)


@pytest.mark.slowtest
@testing.requires_testing_data
def test_calculate_chpi_coil_locs_artemis():
    """Test computing just cHPI locations."""
    raw = read_raw_fif(chpi_fif_fname, allow_maxshield="yes", preload=True)
    # This is a little hack (aliasing while decimating) to make it much faster
    # for testing purposes only. We can relax this later if we find it breaks
    # something.
    raw_dec = _decimate_chpi(raw, 15)
    times, cHPI_digs = _calculate_chpi_coil_locs(raw_dec, verbose="debug")

    # spot check
    assert_allclose(times[0], 9.0, atol=1e-2)
    assert_allclose(
        cHPI_digs[0][2]["r"], [-0.01937833, 0.00346804, 0.06331209], atol=1e-3
    )
    assert_allclose(cHPI_digs[0][2]["gof"], 0.9957, atol=1e-3)

    assert_allclose(cHPI_digs[0][4]["r"], [-0.0655, 0.0755, 0.0004], atol=3e-3)
    assert_allclose(cHPI_digs[0][4]["gof"], 0.9323, atol=1e-3)
    _check_dists(raw.info, cHPI_digs[0], n_bad=1)

    # test on 5k artemis data
    raw = read_raw_artemis123(art_fname, preload=True)
    times, cHPI_digs = _calculate_chpi_coil_locs(raw, verbose="debug")

    assert len(np.setdiff1d(times, raw.times + raw.first_time)) == 0
    # Should be somewhere around 1.5 s, depending on coil GOF values
    # around 0.98 it can change
    assert_allclose(times[5], 1.5, atol=2e-1)
    assert_allclose(cHPI_digs[5][0]["gof"], 0.995, atol=5e-3)
    assert_allclose(cHPI_digs[5][0]["r"], [-0.0157, 0.0655, 0.0018], atol=1e-3)
    _check_dists(raw.info, cHPI_digs[5])
    coil_amplitudes = compute_chpi_amplitudes(raw)
    with pytest.raises(ValueError, match="too_close"):
        compute_chpi_locs(raw.info, coil_amplitudes, too_close="foo")
    # ensure values are in a reasonable range
    amps = np.linalg.norm(coil_amplitudes["slopes"], axis=-1)
    amps /= coil_amplitudes["slopes"].shape[-1]
    assert amps.shape == (len(coil_amplitudes["times"]), 3)
    assert_array_less(amps, 1e-11)
    assert_array_less(1e-13, amps)
    # with nan amplitudes (i.e., cHPI off) it should return an empty array,
    # but still one that is 3D
    coil_amplitudes["slopes"].fill(np.nan)
    chpi_locs = compute_chpi_locs(raw.info, coil_amplitudes)
    assert chpi_locs["rrs"].shape == (0, 3, 3)
    pos = compute_head_pos(raw.info, chpi_locs)
    assert pos.shape == (0, 10)


def assert_suppressed(new, old, suppressed, retained):
    """Assert that some frequencies are suppressed and others aren't."""
    __tracebackhide__ = True
    from scipy.signal import welch

    picks = pick_types(new.info, meg="grad")
    sfreq = new.info["sfreq"]
    new = new.get_data(picks)
    old = old.get_data(picks)
    f, new = welch(new, sfreq, "hann", nperseg=1024)
    _, old = welch(old, sfreq, "hann", nperseg=1024)
    new = np.median(new, axis=0)
    old = np.median(old, axis=0)
    for freqs, lim in ((suppressed, (10, 60)), (retained, (-3, 3))):
        for freq in freqs:
            fidx = np.argmin(np.abs(f - freq))
            this_new = np.median(new[fidx])
            this_old = np.median(old[fidx])
            suppression = -10 * np.log10(this_new / this_old)
            assert lim[0] < suppression < lim[1], freq


@testing.requires_testing_data
def test_chpi_subtraction_filter_chpi():
    """Test subtraction of cHPI signals."""
    raw = read_raw_fif(chpi_fif_fname, allow_maxshield="yes", preload=True)
    raw.info["bads"] = ["MEG0111"]
    raw.del_proj()
    raw_orig = raw.copy().crop(0, 16)
    with catch_logging() as log:
        filter_chpi(raw, include_line=False, t_window=0.2, verbose=True)
    log = log.getvalue()
    assert "No average EEG" not in log
    assert "5 cHPI" in log
    # MaxFilter doesn't do quite as well as our algorithm with the last bit
    raw.crop(0, 16)
    # remove cHPI status chans
    raw_c = read_raw_fif(sss_hpisubt_fname).crop(0, 16).load_data()
    raw_c.pick(["meg", "eeg", "eog", "ecg", "stim", "misc"])
    assert_meg_snr(raw, raw_c, 143, 624)
    # cHPI suppressed but not line freqs (or others)
    assert_suppressed(raw, raw_orig, np.arange(83, 324, 60), [30, 60, 150])
    raw = raw_orig.copy()
    with catch_logging() as log:
        filter_chpi(raw, include_line=True, t_window=0.2, verbose=True)
    log = log.getvalue()
    assert "5 cHPI" in log
    assert "6 line" in log
    # cHPI and line freqs suppressed
    suppressed = np.sort(
        np.concatenate(
            [
                np.arange(83, 324, 60),
                np.arange(60, 301, 60),
            ]
        )
    )
    assert_suppressed(raw, raw_orig, suppressed, [30, 150])

    # No HPI information
    raw = read_raw_fif(sample_fname, preload=True)
    raw_orig = raw.copy()
    assert raw.info["line_freq"] is None
    with pytest.raises(RuntimeError, match="line_freq.*consider setting it"):
        filter_chpi(raw, t_window=0.2)
    with raw.info._unlock():
        raw.info["line_freq"] = 60.0
    with pytest.raises(ValueError, match="No appropriate cHPI information"):
        filter_chpi(raw, t_window=0.2)
    # but this is allowed
    with catch_logging() as log:
        filter_chpi(raw, t_window="auto", allow_line_only=True, verbose=True)
    log = log.getvalue()
    assert "0 cHPI" in log
    assert "1 line" in log
    # Our one line freq suppressed but not others
    assert_suppressed(raw, raw_orig, [60], [30, 45, 75])

    # When MaxFliter downsamples, like::
    #     $ maxfilter -nosss -ds 2 -f test_move_anon_raw.fif \
    #           -o test_move_anon_ds2_raw.fif
    # it can strip out some values of info, which we emulate here:
    raw = read_raw_fif(chpi_fif_fname, allow_maxshield="yes")
    raw = raw.crop(0, 1).load_data().resample(600.0, npad="auto")
    with raw.info._unlock():
        raw.info["lowpass"] = 200.0
        del raw.info["maxshield"]
        del raw.info["hpi_results"][0]["moments"]
        del raw.info["hpi_subsystem"]["event_channel"]
    with catch_logging() as log:
        filter_chpi(raw, t_window="auto", verbose=True)
    with pytest.raises(ValueError, match="must be > 0"):
        filter_chpi(raw, t_window=-1)
    assert "2 cHPI" in log.getvalue()


@testing.requires_testing_data
def test_calculate_head_pos_ctf(tmp_path):
    """Test extracting of cHPI positions from CTF data."""
    raw = read_raw_ctf(ctf_chpi_fname)
    chpi_locs = extract_chpi_locs_ctf(raw)
    quats = compute_head_pos(raw.info, chpi_locs)
    mc_quats = read_head_pos(ctf_chpi_pos_fname)
    mc_quats[:, 9] /= 10000  # had old factor in there twice somehow...
    _assert_quats(
        quats, mc_quats, dist_tol=0.004, angle_tol=2.5, err_rtol=1.0, vel_atol=7e-3
    )  # 7 mm/s
    plot_head_positions(quats, info=raw.info)

    with pytest.raises(RuntimeError, match="Could not find"):
        extract_chpi_locs_ctf(read_raw_fif(ctf_fname))

    # save-load should not affect result
    fname_temp = tmp_path / "test_ctf_raw.fif"
    raw.save(fname_temp)
    raw_read = read_raw_fif(fname_temp)
    # the two attributes used by compute_head_pos
    assert_allclose(
        raw.info["dev_head_t"]["trans"], raw_read.info["dev_head_t"]["trans"]
    )
    with pytest.warns(RuntimeWarning, match="is poor"):
        head_rrs = _get_hpi_initial_fit(raw.info, verbose="debug")
    with pytest.warns(RuntimeWarning, match="is poor"):
        head_rrs_2 = _get_hpi_initial_fit(raw_read.info, verbose="debug")
    assert_allclose(head_rrs, head_rrs_2, atol=1e-5)
    quats_2 = compute_head_pos(raw_read.info, chpi_locs)
    _assert_quats(quats, quats_2, dist_tol=1e-5, angle_tol=0.1)
    chpi_locs_2 = extract_chpi_locs_ctf(raw_read)
    assert_allclose(chpi_locs["rrs"], chpi_locs_2["rrs"], atol=1e-5)
    quats_3 = compute_head_pos(raw_read.info, chpi_locs_2)
    _assert_quats(quats, quats_3, dist_tol=1e-5, angle_tol=0.1)


@testing.requires_testing_data
def test_calculate_head_pos_kit():
    """Test calculation of head position using KIT data."""
    raw = read_raw_kit(con_fname, mrk_fname, elp_fname, hsp_fname)
    assert len(raw.info["hpi_results"]) == 1
    chpi_locs = extract_chpi_locs_kit(raw)
    assert chpi_locs["rrs"].shape == (2, 5, 3)
    assert_array_less(chpi_locs["gofs"], 1.0)
    assert_array_less(0.98, chpi_locs["gofs"])
    quats = compute_head_pos(raw.info, chpi_locs)
    assert quats.shape == (2, 10)
    # plotting works
    plot_head_positions(quats, info=raw.info)
    raw_berlin = read_raw_kit(berlin_fname)
    assert "meg" in raw_berlin
    assert raw_berlin.info["dev_head_t"] is not None
    assert_allclose(raw_berlin.info["dev_head_t"]["trans"], np.eye(4))
    assert len(raw_berlin.info["hpi_results"]) == 0
    with pytest.raises(ValueError, match="Invalid value"):
        extract_chpi_locs_kit(raw_berlin)
    with pytest.raises(RuntimeError, match="not find appropriate"):
        extract_chpi_locs_kit(raw_berlin, "STI 014")
    with pytest.raises(RuntimeError, match="no initial cHPI"):
        compute_head_pos(raw_berlin.info, chpi_locs)


@testing.requires_testing_data
def test_get_active_chpi_ctf():
    """Test extracting of cHPI positions from CTF data."""
    raw = read_raw_ctf(ctf_chpi_fname)
    with pytest.raises(NotImplementedError, match="not implemented for other systems"):
        get_active_chpi(raw)


@testing.requires_testing_data
def test_get_active_chpi_neuromag():
    """Test extracting of cHPI positions from neuromag data."""
    raw = read_raw_fif(chpi_fif_fname, allow_maxshield="yes", preload=True)
    status_ch = raw.ch_names.index("STI201")

    # make artificial chpi signal
    first_three_on = 256 + 512 + 1024
    all_on = 256 + 512 + 1024 + 2048 + 4096
    raw._data[status_ch][:1000] = 0
    raw._data[status_ch][1000:2000] = first_three_on
    raw._data[status_ch][2000:] = all_on

    # build target signal
    target_signal = 5 * np.ones_like(raw.times)
    target_signal[:1000] = 0
    target_signal[1000:2000] = 3

    assert_allclose(get_active_chpi(raw, on_missing="ignore"), target_signal)

    raw_no_chpi = read_raw_fif(sample_fname)
    assert_allclose(
        get_active_chpi(raw_no_chpi, on_missing="ignore"),
        np.zeros_like(raw_no_chpi.times),
    )


def assert_slopes_correlated(actual_meas, desired_meas, *, lim=(0.99, 1.0)):
    """Assert that slopes in two coil info dicts are all close."""
    __tracebackhide__ = True
    assert len(actual_meas["hpi_coils"]) == len(desired_meas["hpi_coils"])
    for ci, (c1, c2) in enumerate(
        zip(actual_meas["hpi_coils"], desired_meas["hpi_coils"])
    ):
        corr = np.abs(np.corrcoef(c1["slopes"].ravel(), c2["slopes"].ravel())[0, 1])
        assert lim[0] <= corr <= lim[1], f"meas['hpi_coils'][{ci}] corr: {corr}"


@pytest.mark.slowtest
@testing.requires_testing_data
def test_refit_hpi_locs_basic():
    """Test that HPI locations can be refit."""
    raw = read_raw_fif(chpi_fif_fname, allow_maxshield="yes").crop(0, 2).load_data()
    # These should be similar (and both should work)
    locs = compute_chpi_amplitudes(raw, t_step_min=2, t_window=1)
    locs_2 = compute_chpi_amplitudes(raw, t_step_min=2, t_window=1, ext_order=0)
    corr = np.corrcoef(locs["slopes"].ravel(), locs_2["slopes"].ravel())[0, 1]
    assert 0.999 < corr < 1.0
    info = raw.info
    del raw

    # Refit on these data won't change much
    info_new = info.copy()
    assert len(info["hpi_results"][-1]["used"]) == 3
    refit_hpi(info_new, amplitudes=False, locs=False, use=3)
    assert len(info_new["hpi_results"]) == len(info["hpi_results"]) == 1
    assert len(info_new["hpi_meas"]) == len(info["hpi_meas"]) == 1
    assert_trans_allclose(
        info_new["dev_head_t"],
        info["dev_head_t"],
        dist_tol=0.1e-3,
        angle_tol=0.1,
    )
    # Refit with more coils than hpifit (our default is use=None)
    refit_hpi(info_new, amplitudes=False, locs=False)
    assert len(info_new["hpi_results"][-1]["used"]) == 5
    assert_trans_allclose(
        info_new["dev_head_t"],
        info["dev_head_t"],
        dist_tol=3e-3,
        angle_tol=2,
    )
    # Refit locations
    refit_hpi(info_new, amplitudes=False)  # default: locs=True
    assert_trans_allclose(
        info_new["dev_head_t"],
        info["dev_head_t"],
        dist_tol=2e-3,
        angle_tol=2,
    )
    assert_array_equal(
        info_new["hpi_results"][-1]["order"], info["hpi_results"][-1]["order"]
    )
    assert_slopes_correlated(
        info_new["hpi_meas"][-1],
        info["hpi_meas"][-1],
        lim=(0.999999, 1.0),
    )
    with pytest.raises(ValueError, match="must also be True"):
        refit_hpi(info_new, locs=False)
    # Refit locations and amplitudes (with ext_order=0 just to make sure it works)
    refit_hpi(info_new, ext_order=0)
    assert_trans_allclose(
        info_new["dev_head_t"],
        info["dev_head_t"],
        dist_tol=2e-3,
        angle_tol=2,
    )
    assert_array_equal(
        info_new["hpi_results"][-1]["order"], info["hpi_results"][-1]["order"]
    )
    assert_slopes_correlated(
        info_new["hpi_meas"][-1], info["hpi_meas"][-1], lim=(0.99, 0.999999)
    )


@testing.requires_testing_data
def test_refit_hpi_locs_problematic():
    """Test that we can fix problematic HPI fits."""
    info_bad = read_info(chpi_problem_fname)
    ang, dist = angle_distance_between_rigid(
        info_bad["dev_head_t"]["trans"], angle_units="deg", distance_units="mm"
    )
    assert_allclose(ang, 177, atol=1)  # upside-down!
    assert_allclose(dist, 61, atol=1)
    orig_order = [4, 2, 1, 3, 5]
    good_order = [1, 2, 4, 3, 5]
    assert_array_equal(info_bad["hpi_results"][-1]["order"], orig_order)
    orig_use = info_bad["hpi_results"][-1]["used"]
    assert_array_equal(orig_use, [2, 3, 5])
    with pytest.warns(RuntimeWarning, match="colinear"):
        info_new = refit_hpi(
            info_bad.copy(),
            amplitudes=False,
            locs=False,
            order=False,
            use=orig_use - 1,
            dist_limit=np.inf,
            colinearity_limit=0.03,
        )
    assert_array_equal(info_new["hpi_results"][-1]["order"], orig_order)
    assert_array_equal(info_new["hpi_results"][-1]["used"], orig_use)
    assert_trans_allclose(
        info_new["dev_head_t"],
        info_bad["dev_head_t"],
        dist_tol=1e-3,
        angle_tol=1,
    )
    # Even just allowing our permutation checker to run helps
    with pytest.raises(RuntimeError, match="need at least 3"):
        refit_hpi(info_bad.copy(), amplitudes=False, locs=False, order=False)
    info_new = refit_hpi(
        info_bad.copy(),
        amplitudes=False,
        locs=False,
        dist_limit=0.02,
        colinearity_limit=0.03,
    )
    assert_array_equal(info_new["hpi_results"][-1]["order"], good_order)
    ang, dist = angle_distance_between_rigid(
        info_new["dev_head_t"]["trans"],
        angle_units="deg",
        distance_units="mm",
    )
    assert 10 < ang < 15  # much more upright!
    assert 75 < dist < 80
    with pytest.warns(RuntimeWarning, match="Discrepancy"):
        # We can run this with amplitudes=True, but it's much faster not to
        # (and the result is very similar)
        info_new = refit_hpi(
            info_bad.copy(), amplitudes=False, dist_limit=0.01, colinearity_limit=0.03
        )
    assert_array_equal(info_new["hpi_results"][-1]["order"], good_order)
    ang, dist = angle_distance_between_rigid(
        info_new["dev_head_t"]["trans"],
        angle_units="deg",
        distance_units="mm",
    )
    assert 3 < ang < 6
    assert 82 < dist < 87
